{ "cells": [ { "cell_type": "markdown", "id": "e909f87b-0738-4aea-9388-6ffe258db9f2", "metadata": {}, "source": [ "# Deploy Stable Cascade for Real-Time Image Generation on SageMaker\n", "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.\n", "\n", "![ This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "---\n", "\n", "For further reading & reference materials:\n", "\n", "Sources: https://www.philschmid.de/sagemaker-stable-diffusion#2-create-sagemaker-modeltargz-artifact\n", "\n", "Further Reading: https://huggingface.co/stabilityai/stable-cascade" ] }, { "cell_type": "code", "execution_count": null, "id": "9b7a9e0a-6325-4639-85c1-ca3b4083feb9", "metadata": {}, "outputs": [], "source": [ "# %pip (not !pip) installs into the environment of the running kernel\n", "%pip install 'sagemaker<3.0' huggingface_hub diffusers transformers accelerate safetensors tokenizers torch --upgrade --q\n", "%pip install python-dotenv --upgrade --q" ] }, { "cell_type": "code", "execution_count": null, "id": "30738e7b-cb07-4366-84af-026bee55bdb7", "metadata": {}, "outputs": [], "source": [ "import base64\n", "import boto3\n", "import json\n", "import matplotlib.pyplot as plt\n", "import os\n", "import random\n", "import sagemaker\n", "import tarfile\n", "import time\n", "import torch\n", "\n", "from diffusers import (\n", " StableCascadePriorPipeline,\n", " StableCascadeDecoderPipeline,\n", " StableCascadeUNet,\n", ")\n", "from distutils.dir_util import copy_tree\n", "from dotenv import load_dotenv\n", "from huggingface_hub import snapshot_download\n", "from io import BytesIO\n", "from pathlib import Path\n", "from IPython.display import display\n", "from PIL import Image\n", "from sagemaker import get_execution_role\n",
"from sagemaker.s3 import S3Uploader, S3Downloader\n", "from sagemaker.huggingface.model import HuggingFaceModel\n", "from sagemaker.async_inference import AsyncInferenceConfig\n", "\n", "load_dotenv()\n", "\n", "sess = sagemaker.Session()\n", "print(f\"Sagemaker bucket: {sess.default_bucket()}\")\n", "print(f\"Sagemaker session region: {sess.boto_region_name}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "18f7eef1-b2cd-4137-ba57-974f025dc1dd", "metadata": { "scrolled": true }, "outputs": [], "source": [ "HF_PRIOR_ID = \"stabilityai/stable-cascade-prior\"\n", "HF_DECODER_ID = \"stabilityai/stable-cascade\"\n", "CACHE_DIR = os.getenv(\"CACHE_DIR\", \"cache_dir\")\n", "\n", "prior_unet = StableCascadeUNet.from_pretrained(HF_PRIOR_ID, subfolder=\"prior_lite\")\n", "decoder_unet = StableCascadeUNet.from_pretrained(HF_DECODER_ID, subfolder=\"decoder_lite\")\n", "\n", "prior = StableCascadePriorPipeline.from_pretrained(\n", " HF_PRIOR_ID,\n", " variant=\"bf16\",\n", " torch_dtype=torch.bfloat16,\n", " cache_dir=CACHE_DIR,\n", " prior=prior_unet,\n", ")\n", "decoder = StableCascadeDecoderPipeline.from_pretrained(\n", " HF_DECODER_ID,\n", " variant=\"bf16\",\n", " torch_dtype=torch.bfloat16,\n", " cache_dir=CACHE_DIR,\n", " decoder=decoder_unet,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "1b15454b-7449-4d06-b10a-276181734a89", "metadata": {}, "outputs": [], "source": [ "model_path = \"model/\"\n", "prior_path = \"model/prior/\"\n", "decoder_path = \"model/decoder/\"\n", "code_path = \"code/\"\n", "cache_dir = \"cache_dir/\"\n", "\n", "# Create all working directories idempotently; makedirs also creates parent dirs,\n", "# so the order of the paths no longer matters.\n", "for _dir in (model_path, prior_path, decoder_path, code_path, cache_dir):\n", " os.makedirs(_dir, exist_ok=True)\n", "\n", "prior.save_pretrained(save_directory=prior_path)\n",
"decoder.save_pretrained(save_directory=decoder_path)" ] }, { "cell_type": "code", "execution_count": null, "id": "16ff336e-7aee-40a5-a1e2-3e57ac542511", "metadata": {}, "outputs": [], "source": [ "# Perform local inference in notebook to validate model loading and inference call\n", "\n", "prior = StableCascadePriorPipeline.from_pretrained(prior_path, local_files_only=True)\n", "decoder = StableCascadeDecoderPipeline.from_pretrained(decoder_path, local_files_only=True)\n", "prompt = \"an image of a shiba inu, donning a spacesuit and helmet\"\n", "negative_prompt = \"\"\n", "\n", "# Uncomment to run on GPU\n", "# prior.enable_model_cpu_offload()\n", "prior_output = prior(\n", " prompt=prompt,\n", " height=1024,\n", " width=1024,\n", " negative_prompt=negative_prompt,\n", " guidance_scale=4.0,\n", " num_images_per_prompt=1,\n", " num_inference_steps=20,\n", ")\n", "\n", "# Uncomment to run on GPU\n", "# decoder.enable_model_cpu_offload()\n", "decoder_output = decoder(\n", " image_embeddings=prior_output.image_embeddings,\n", " prompt=prompt,\n", " negative_prompt=negative_prompt,\n", " guidance_scale=0.0,\n", " output_type=\"pil\",\n", " num_inference_steps=10,\n", ").images[0]\n", "decoder_output.save(\"cascade.png\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d76117ef-f131-48bf-bc83-c2f2a6652a2c", "metadata": {}, "outputs": [], "source": [ "%%writefile code/requirements.txt\n", "--find-links https://download.pytorch.org/whl/torch_stable.html\n", "accelerate>=0.25.0\n", "torch==2.1.2+cu118\n", "torchvision==0.16.2+cu118\n", "transformers>=4.30.0\n", "numpy>=1.23.5\n", "kornia>=0.7.0\n", "insightface>=0.7.3\n", "opencv-python>=4.8.1.78\n", "tqdm>=4.66.1\n", "matplotlib>=3.7.4\n", "webdataset>=0.2.79\n", "wandb>=0.16.2\n", "munch>=4.0.0\n", "onnxruntime>=1.16.3\n", "einops>=0.7.0\n", "onnx2torch>=1.5.13\n", "warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git\n", "torchtools @ git+https://github.com/pabloppp/pytorch-tools\n", "diffusers" ] }, { "cell_type": "code", "execution_count": null, "id": "5802c2fa-f333-4137-ac1e-1f902c2ad382", "metadata": {}, "outputs": [], "source": [ "%%writefile code/inference.py\n", "import base64\n", "import json\n", "import os\n", "import torch\n", "\n", "from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline\n", "from io import BytesIO\n", "\n", "\n", "def model_fn(model_dir):\n", " \"\"\"\n", " Load the model for inference\n", " \"\"\"\n", " print(\"Entering model_fn...\")\n", " print(f\"Model Directory is {model_dir}\")\n", "\n", " prior = StableCascadePriorPipeline.from_pretrained(f\"{model_dir}/prior\", local_files_only=True)\n", " decoder = StableCascadeDecoderPipeline.from_pretrained(\n", " f\"{model_dir}/decoder\", local_files_only=True\n", " )\n", "\n", " model_dict = {\"prior\": prior, \"decoder\": decoder}\n", " print(f\"model dictionary: {model_dict}\")\n", " return model_dict\n", "\n", "\n", "def predict_fn(input_data, model_dict):\n", " \"\"\"\n", " Apply model to the incoming request\n", " \"\"\"\n", " print(\"Entering predict_fn...\")\n", " prior = model_dict[\"prior\"]\n", " decoder = model_dict[\"decoder\"]\n", "\n", " print(f\"Processing input_data {input_data}\")\n", " prompt = input_data[\"prompt\"]\n", " negative_prompt = input_data[\"negative_prompt\"]\n", " print(f\"Prompt = {prompt}\")\n", " print(f\"Negative Prompt = {negative_prompt}\")\n", "\n", " prior.enable_model_cpu_offload()\n", " prior_output = prior(\n", " prompt=prompt,\n", " height=1024,\n", " width=1024,\n", " negative_prompt=negative_prompt,\n", " guidance_scale=4.0,\n", " num_images_per_prompt=1,\n", " num_inference_steps=20,\n", " )\n", "\n", " decoder.enable_model_cpu_offload()\n", " decoder_output = decoder(\n", " image_embeddings=prior_output.image_embeddings,\n", " prompt=prompt,\n", " negative_prompt=negative_prompt,\n", " guidance_scale=0.0,\n", " output_type=\"pil\",\n", " num_inference_steps=10,\n", " ).images[0]\n", "\n", " encoded_images = []\n", " buffered = BytesIO()\n", " decoder_output.save(buffered, format=\"JPEG\")\n", " encoded_images.append(base64.b64encode(buffered.getvalue()).decode())\n", " print(\"Finished encoding returned images.\")\n", " return {\"generated_images\": encoded_images}" ] }, { "cell_type": "code", "execution_count": null, "id": "a9c1554b-71e0-4755-9661-c0df74fc91ce", "metadata": {}, "outputs": [], "source": [ "# assemble model package\n", "import shutil\n", "\n", "model_tar = Path(f\"model-{random.getrandbits(16)}\")\n", "model_tar.mkdir(exist_ok=True)\n", "\n", "# shutil.copytree replaces the deprecated distutils.dir_util.copy_tree\n", "# (distutils was removed from the standard library in Python 3.12)\n", "shutil.copytree(prior_path, str(model_tar.joinpath(\"prior\")), dirs_exist_ok=True)\n", "shutil.copytree(decoder_path, str(model_tar.joinpath(\"decoder\")), dirs_exist_ok=True)\n", "shutil.copytree(code_path, str(model_tar.joinpath(\"code\")), dirs_exist_ok=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "b42f83a4-e0c8-46ed-b477-3fe3d617e365", "metadata": {}, "outputs": [], "source": [ "# helper to create the model.tar.gz\n", "def compress(tar_dir=None, output_file=\"model.tar.gz\"):\n", " \"\"\"Tar-gzip the contents of tar_dir into output_file in the caller's directory.\"\"\"\n", " parent_dir = os.getcwd()\n", " os.chdir(tar_dir)\n", " try:\n", " with tarfile.open(os.path.join(parent_dir, output_file), \"w:gz\") as tar:\n", " for item in os.listdir(\".\"):\n", " print(item)\n", " tar.add(item, arcname=item)\n", " finally:\n", " # always restore the working directory, even if tarring fails\n", " os.chdir(parent_dir)\n", "\n", "\n", "compress(str(model_tar))" ] }, { "cell_type": "code", "execution_count": null, "id": "08218d56-f750-4414-83a7-db36068234ec", "metadata": {}, "outputs": [], "source": [ "# upload model.tar.gz to s3\n", "s3_model_uri = S3Uploader.upload(\n", " local_path=\"model.tar.gz\",\n", " desired_s3_uri=f\"s3://{sess.default_bucket()}/stable-cascade\",\n", ")\n", "print(f\"model uploaded to: {s3_model_uri}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "cc684183-8da4-4eed-9e6a-e30c2bc5471b", "metadata": {}, "outputs": [], "source": [ "%cd" ] }, { "cell_type": "code", "execution_count": null, "id": "4b40ece5-67bf-4bc6-987b-7306310f9d6c", "metadata": {}, "outputs": [], "source": [ 
"# helper decoder\n", "def decode_base64_image(image_string):\n", " base64_image = base64.b64decode(image_string)\n", " buffer = BytesIO(base64_image)\n", " return Image.open(buffer)\n", "\n", "\n", "# display PIL images as grid\n", "def display_images(images=None, columns=3, width=100, height=100):\n", " plt.figure(figsize=(width, height))\n", " for i, image in enumerate(images):\n", " plt.subplot(int(len(images) / columns + 1), columns, i + 1)\n", " plt.axis(\"off\")\n", " plt.imshow(image)" ] }, { "cell_type": "code", "execution_count": null, "id": "9228d7a7-497d-40ff-aecd-1b135813e186", "metadata": {}, "outputs": [], "source": [ "# create Hugging Face Model Class\n", "huggingface_model = HuggingFaceModel(\n", " model_data=s3_model_uri,\n", " role=get_execution_role(sess),\n", " transformers_version=\"4.17\",\n", " pytorch_version=\"1.10\",\n", " py_version=\"py38\",\n", ")\n", "\n", "# deploy the endpoint\n", "predictor = huggingface_model.deploy(initial_instance_count=1, instance_type=\"ml.g5.48xlarge\")" ] }, { "cell_type": "code", "execution_count": null, "id": "25d06a2e-12b8-4774-83d1-9a014be56faf", "metadata": {}, "outputs": [], "source": [ "start_time = time.time()\n", "\n", "# synchronous invoke_endpoint API call\n", "client = boto3.client(\"sagemaker-runtime\")\n", "prompt = \"A dog trying to catch a flying pizza art\"\n", "payload = {\"prompt\": prompt, \"negative_prompt\": \"\"}\n", "\n", "serialized_payload = json.dumps(payload, indent=4)\n", "with open(\"payload.json\", \"w\") as outfile:\n", " outfile.write(serialized_payload)\n", "\n", "response = client.invoke_endpoint(\n", " EndpointName=predictor.endpoint_name,\n", " ContentType=\"application/json\", # Specify the format of the payload\n", " Accept=\"application/json\",\n", " Body=serialized_payload,\n", ")\n", "print(f\"inference response: {response}\")\n", "\n", "response_payload = json.loads(response[\"Body\"].read().decode(\"utf-8\"))\n", "\n", "# decode 
images\n", "decoded_images = [decode_base64_image(image) for image in response_payload[\"generated_images\"]]\n", "\n", "# visualize generation\n", "display_images(decoded_images)\n", "\n", "end_time = time.time()\n", "inference_time = end_time - start_time\n", "print(f\"total inference time = {inference_time}\")" ] }, { "cell_type": "markdown", "id": "123ac143-dfaf-4944-8b57-e1c043153d59", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n", "\n", "\n", "![ This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This sa-east-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This ap-southeast-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n", "\n", "![ This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }